import sys
import json
from common.log import log_handler
from sklearn import metrics
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import copy
from common.utils.format_util import java_to_py, py_to_java
from common.database import OLTP
from common.log import log_handler

# Module-level logger shared by all helpers below.
log = log_handler.LogHandler().get_log()

try:
    # These packages need to be importable while autosklearn runs; the
    # eval() in get_model_info_auto_sklearn also relies on some of these
    # names being defined, so keep the imports even though they look unused.
    from autosklearn.pipeline.regression import SimpleRegressionPipeline
    from autosklearn.pipeline.classification import SimpleClassificationPipeline
    from sklearn.dummy import DummyClassifier, DummyRegressor
    from autosklearn.evaluation.abstract_evaluator import MyDummyClassifier, MyDummyRegressor
except ImportError:
    # Best-effort: the rest of the module still works without auto-sklearn.
    log.error("haven't install auto-sklearn")

# NOTE(review): appended after the imports above, so it only affects imports
# performed later (e.g. by callers) -- confirm this ordering is intentional.
sys.path.append("../..")


def get_model_info_sklearn(model_id):
    """Load stored metadata for a trained sklearn model from the `model` table.

    Parameters
    ----------
    model_id : str
        Primary key of the row in the `model` table.

    Returns
    -------
    tuple
        (model_path, label_list, feature_list) where model_path is the saved
        path relative to the "ml-model/" prefix, label_list defaults to []
        when absent, and feature_list is taken from the stored other_info.
    """
    # NOTE(review): model_id is interpolated directly into the SQL string --
    # ensure it is trusted/validated upstream, or switch to a parameterized
    # query if the OLTP helper supports one.
    series = OLTP.execute_query("select * from model where id = '{}'".format(model_id)).iloc[-1]

    model_path = series["model_saved_path"].split("ml-model/")[1]

    # other_info is stored as a JSON document (the old code rewrote
    # false/true/null into Python literals and eval()'d it); parse it
    # directly instead of eval-ing database content.
    other_info = json.loads(series["other_info"])

    label_list = other_info.get("label_list", [])
    feature_list = other_info["feature_list"]
    return model_path, label_list, feature_list


def eval_metric(y_true, y_pred, compute_confusion_matrix=0):
    """Return macro-averaged classification metrics for the given labels.

    Returns a 5-tuple (accuracy, precision, recall, F1, confusion_matrix);
    the confusion matrix is an empty list unless compute_confusion_matrix
    is truthy.
    """
    accuracy = metrics.accuracy_score(y_true, y_pred)
    precision = metrics.precision_score(y_true, y_pred, average="macro")
    recall = metrics.recall_score(y_true, y_pred, average="macro")
    f1 = metrics.f1_score(y_true, y_pred, average="macro")
    # Computing the matrix is optional because callers may not need it.
    cm = metrics.confusion_matrix(y_true, y_pred) if compute_confusion_matrix else []
    return accuracy, precision, recall, f1, cm


def preprocess(dataframe, para_list, feature_list):
    """One-hot encode string feature columns and stack them after the numeric ones.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Input data containing at least the requested feature columns.
    para_list : dict
        Request parameters; "execution" == "predict" carries the columns as a
        comma-separated string in "feature_X", otherwise as a list in
        "feature_col".
    feature_list : list
        Encoded feature names from training. When empty (training path) it is
        FILLED IN PLACE with numeric column names followed by
        "<col>_OneHot/<category>" entries, so the caller can persist it.

    Returns
    -------
    numpy.ndarray
        Numeric columns concatenated with the one-hot encoded string columns.
    """
    separator = "_OneHot/"

    if para_list["execution"] == "predict":
        feature_X = para_list["feature_X"].split(",")
    else:
        feature_X = para_list["feature_col"]

    # Classify each column as numeric or string from the first row's values.
    # NOTE(review): assumes the first row is representative (e.g. no NaN in a
    # string column) -- confirm upstream cleaning guarantees this.
    num_cols, str_cols = [], []
    first_row = dataframe[feature_X].values[0]
    for col, value in zip(feature_X, first_row):
        if isinstance(value, str):
            str_cols.append(col)
        else:
            num_cols.append(col)

    num_data = dataframe[num_cols].values
    if not feature_list:
        # Training path: fit a fresh encoder and record the produced feature
        # names into feature_list (mutated in place for the caller).
        feature_list += num_cols
        if str_cols:
            encoder = OneHotEncoder()
            encoded_str_data = encoder.fit_transform(dataframe[str_cols].values).toarray()
            encoded_data = np.concatenate((num_data, encoded_str_data), axis=1)
            for col, enums in zip(str_cols, encoder.categories_):
                for enum in enums:
                    feature_list.append(col + separator + enum)
        else:
            encoded_data = num_data
    else:
        # Predict path: rebuild the per-column category lists recorded at
        # training time so the encoder emits columns in the same order.
        if str_cols:
            str_feature_list = feature_list[len(num_cols):]
            categories = [[] for _ in str_cols]
            seen_cols = []
            index = -1
            for str_feature_col in str_feature_list:
                col, enum = str_feature_col.split(separator)
                if col not in seen_cols:
                    seen_cols.append(col)
                    index += 1
                categories[index].append(enum)
            # BUG FIX: `categories` must be passed by keyword. Passing it
            # positionally is rejected by modern scikit-learn (keyword-only
            # params) and was silently wrong on old versions (it filled the
            # deprecated n_values slot).
            encoder = OneHotEncoder(categories=categories)
            encoded_str_data = encoder.fit_transform(dataframe[str_cols].values).toarray()
            encoded_data = np.concatenate((num_data, encoded_str_data), axis=1)
        else:
            encoded_data = num_data
    return encoded_data


def semantic_filling(new_columns, data_json, para_list):
    """Fill output semantics/column types for the new prediction columns.

    NOTE(review): this definition is DEAD CODE -- a second `semantic_filling`
    defined later in this module shadows it at import time. The two copies
    differ: here "AUTOR" shares the PCA (FLOAT-per-column) branch, while the
    live copy puts "AUTOR" in the LR branch. Confirm which mapping is intended
    and delete the stale duplicate.
    """
    semantic_new = {}
    # TODO
    # semantic_old = copy.deepcopy(data_json["input"][0]["semantic"])
    # for key,value in semantic_old.items():
    #     new_key = key.lower().replace(" ","_").replace("/","_").replace("-","_")
    #     semantic_new[new_key] = value
    # The semantic map starts empty until the TODO above is implemented.
    data_json["output"][0]["semantic"] = semantic_new
    log.info(data_json["output"])
    algo = para_list["algo"]
    if algo in ["KM"]:
        for new_column in new_columns:
            data_json["output"][0]["semantic"][new_column] = "disorder"
            data_json["output"][0]["columnTypes"].append("BIGINT")
    elif algo in ["PCA", "AUTOR"]:
        for new_column in new_columns:
            data_json["output"][0]["semantic"][new_column] = "null"
            data_json["output"][0]["columnTypes"].append("FLOAT")
    elif algo in ["DTC", "AUTOC"]:
        # Always false while semantic starts as {} (see TODO above).
        if len(data_json["output"][0]["semantic"].values()) != 0:
            for new_column in new_columns:
                data_json["output"][0]["semantic"][new_column] = "null"
                # data_json["output"][0]["semantic"][para_list["label_col"]]
        # index = data_json["output"][0]["tableCols"].index(para_list["label_col"])
        data_json["output"][0]["columnTypes"].append("VARCHAR")
    elif algo in ["LR"]:
        # Always false while semantic starts as {} (see TODO above).
        if len(data_json["output"][0]["semantic"].values()) != 0:
            for new_column in new_columns:
                data_json["output"][0]["semantic"][new_column] = "null"
                # data_json["output"][0]["semantic"][para_list["label_col"]]
        # index = data_json["output"][0]["tableCols"].index(para_list["label_col"])
        data_json["output"][0]["columnTypes"].append("FLOAT")
    log.info("UPDATED DATA JSON OUTPUT:")
    log.info(data_json["output"])
    return data_json


def update_task_when_pred(task, df_pred, para_list, target):
    """Rewrite task.data_json so its output section describes the prediction frame.

    Copies the input section to output, appends any columns the prediction
    DataFrame added, fills their semantics/types via semantic_filling, points
    the output at `target`, and writes the converted JSON back onto `task`.
    """
    raw_json = java_to_py(task.data_json)
    log.info(raw_json)
    data_json = json.loads(raw_json)

    # Output starts as an independent copy of the input description.
    data_json["output"] = copy.deepcopy(data_json["input"])
    existing_cols = data_json["output"][0]["tableCols"]
    added_cols = [col for col in df_pred.columns.to_list() if col not in existing_cols]
    data_json["output"][0]["tableCols"] = existing_cols + added_cols

    semantic_filling(added_cols, data_json, para_list)
    log.info("SEMANTIC FILLED")
    data_json["output"][0]["tableName"] = target

    # NOTE(review): str(dict) produces a single-quoted Python repr, not JSON;
    # py_to_java presumably normalizes it back -- confirm against format_util.
    task.data_json = py_to_java(str(data_json))
    return task


def semantic_filling(new_columns, data_json, para_list):
    """Fill output semantics and column types for the newly added columns.

    The output semantic map is reset to empty (carrying over input semantics
    is still TODO), then per-algorithm rules add semantics and columnTypes:
    KM appends one BIGINT per new column, PCA one FLOAT per new column, while
    DTC/AUTOC and LR/AUTOR append a single VARCHAR/FLOAT respectively.
    Mutates data_json in place and also returns it.
    """
    output = data_json["output"][0]
    # TODO: copy the input semantics with normalized keys (lower-cased,
    # spaces/"/"/"-" replaced by "_"); until then the map starts empty.
    output["semantic"] = {}
    log.info(data_json["output"])

    algo = para_list["algo"]
    if algo == "KM":
        for column in new_columns:
            output["semantic"][column] = "disorder"
            output["columnTypes"].append("BIGINT")
    elif algo == "PCA":
        for column in new_columns:
            output["semantic"][column] = "null"
            output["columnTypes"].append("FLOAT")
    elif algo in ("DTC", "AUTOC"):
        # The semantic map is empty at this point, so this loop is currently
        # skipped; kept for when the TODO above restores input semantics.
        if output["semantic"]:
            for column in new_columns:
                output["semantic"][column] = "null"
        output["columnTypes"].append("VARCHAR")
    elif algo in ("LR", "AUTOR"):
        if output["semantic"]:
            for column in new_columns:
                output["semantic"][column] = "null"
        output["columnTypes"].append("FLOAT")

    log.info("UPDATED DATA JSON OUTPUT:")
    log.info(data_json["output"])
    return data_json


def get_model_info_auto_sklearn(model, type):
    """Summarize the ensemble members of a fitted auto-sklearn estimator.

    Parameters
    ----------
    model : auto-sklearn estimator
        A fitted estimator whose show_models() output can be eval'd into
        a list of (weight, pipeline) pairs.
    type : str
        Config-key prefix for the model choice, e.g. "classifier" or
        "regressor". (Shadows the builtin `type`; kept for caller compat.)

    Returns
    -------
    list[dict] | str
        One dict per ensemble member (id, model, weight, data_preprocessing,
        feature_preprocessor, model_parameters), or a user-facing error
        string when only a dummy/unconfigured model was found.
    """
    # NOTE(review): eval of show_models() output is deliberate -- the guarded
    # autosklearn/sklearn imports at the top of this module exist so the
    # repr'd pipeline classes resolve. Only safe on trusted estimators.
    models = eval(model.show_models())
    rets = []
    for i, model_info in enumerate(models):
        # NOTE(review): config == 1 appears to mark a dummy/unfitted entry --
        # confirm against the auto-sklearn version in use.
        if model_info[1].config == 1:
            return "没有找到合适的模型，请延长最大训练时长或者重新选择目标列或者重新选择候选模型"
        config = model_info[1].config.get_dictionary()
        # Round-trip through json with sort_keys to get deterministic ordering.
        sort_config = json.loads(json.dumps(config, sort_keys=True))

        ret = {}
        ret["id"] = i
        data_preprocessing = {}
        feature_preprocessor = {}
        model_parameters = {}
        # Renamed from `model` to avoid clobbering the function parameter.
        model_name = sort_config[type + ":__choice__"]
        ret["model"] = model_name

        ret["weight"] = model_info[0]

        # Keys look like "<section>:<component>:<param>"; keep everything
        # after the second ":" as the parameter name.
        for key in sort_config:
            new_key = ":".join(key.split(":")[2:])
            if key.startswith("data_preprocessing"):
                data_preprocessing[new_key] = sort_config[key]
            if key.startswith("feature_preprocessor"):
                if new_key == "":
                    new_key = "__choice__"
                feature_preprocessor[new_key] = sort_config[key]
            if key.startswith(type):
                if new_key:
                    model_parameters[new_key] = sort_config[key]
        ret["data_preprocessing"] = data_preprocessing
        ret["feature_preprocessor"] = feature_preprocessor
        ret["model_parameters"] = model_parameters
        rets.append(ret)
    return rets
